Building neural network & competitor

Training a deep neural network and a “traditional” competitor model to detect spam emails. The competitor model ended up being a tuned Random Forest, which showed more promising results across 10 folds than the other candidates I evaluated (Penalized Logistic Regression & Naive Bayes). On the test set, the Random Forest also performed slightly better than the deep neural network.

Packages

library(discrim)
library(keras3)
library(tensorflow)
library(tidymodels)

Data

General prep
set.seed(42)
spam <- readr::read_csv(here::here("data/spam.csv"))
spam <-
spam |>
mutate(
# tidymodels treats the first factor level as the event of interest:
spam = factor(
if_else(spam == 0, "no spam", "spam"),
ordered = TRUE,
levels = c("spam", "no spam")
)
)
# Data split (60/20/20):
spam_split <- initial_validation_split(spam, prop = c(0.6, 0.2), strata = "spam")
train <- training(spam_split)
val <- validation(spam_split)
test <- testing(spam_split)

Getting an overview:
glimpse(train)

Rows: 2,759
Columns: 58
$ word_freq_make <dbl> 0.00, 0.21, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_address <dbl> 0.64, 0.28, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_all <dbl> 0.64, 0.50, 0.71, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_3d <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_our <dbl> 0.32, 0.14, 1.23, 0.63, 0.63, 1.92, 1.88, 0…
$ word_freq_over <dbl> 0.00, 0.28, 0.19, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_remove <dbl> 0.00, 0.21, 0.19, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_internet <dbl> 0.00, 0.07, 0.12, 0.63, 0.63, 0.00, 1.88, 0…
$ word_freq_order <dbl> 0.00, 0.00, 0.64, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_mail <dbl> 0.00, 0.94, 0.25, 0.63, 0.63, 0.64, 0.00, 0…
$ word_freq_receive <dbl> 0.00, 0.21, 0.38, 0.31, 0.31, 0.96, 0.00, 0…
$ word_freq_will <dbl> 0.64, 0.79, 0.45, 0.31, 0.31, 1.28, 0.00, 0…
$ word_freq_people <dbl> 0.00, 0.65, 0.12, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_report <dbl> 0.00, 0.21, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_addresses <dbl> 0.00, 0.14, 1.75, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_free <dbl> 0.32, 0.14, 0.06, 0.31, 0.31, 0.96, 0.00, 0…
$ word_freq_business <dbl> 0.00, 0.07, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_email <dbl> 1.29, 0.28, 1.03, 0.00, 0.00, 0.32, 0.00, 0…
$ word_freq_you <dbl> 1.93, 3.47, 1.36, 3.18, 3.18, 3.85, 0.00, 1…
$ word_freq_credit <dbl> 0.00, 0.00, 0.32, 0.00, 0.00, 0.00, 0.00, 3…
$ word_freq_your <dbl> 0.96, 1.59, 0.51, 0.31, 0.31, 0.64, 0.00, 2…
$ word_freq_font <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_000 <dbl> 0.00, 0.43, 1.16, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_money <dbl> 0.00, 0.43, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_hp <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_hpl <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_george <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_650 <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_lab <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_labs <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_telnet <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_857 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_data <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_415 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_85 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_technology <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_1999 <dbl> 0.00, 0.07, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_parts <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_pm <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_direct <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_cs <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_meeting <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_original <dbl> 0.00, 0.00, 0.12, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_project <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_re <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_edu <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_table <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_conference <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ `char_freq_;` <dbl> 0.000, 0.000, 0.010, 0.000, 0.000, 0.000, 0…
$ `char_freq_(` <dbl> 0.000, 0.132, 0.143, 0.137, 0.135, 0.054, 0…
$ `char_freq_[` <dbl> 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0…
$ `char_freq_!` <dbl> 0.778, 0.372, 0.276, 0.137, 0.135, 0.164, 0…
$ `char_freq_$` <dbl> 0.000, 0.180, 0.184, 0.000, 0.000, 0.054, 0…
$ `char_freq_#` <dbl> 0.000, 0.048, 0.010, 0.000, 0.000, 0.000, 0…
$ capital_run_length_average <dbl> 3.756, 5.114, 9.821, 3.537, 3.537, 1.671, 2…
$ capital_run_length_longest <dbl> 61, 101, 485, 40, 40, 4, 11, 445, 43, 24, 5…
$ capital_run_length_total <dbl> 278, 1028, 2259, 191, 191, 112, 49, 1257, 7…
$ spam <ord> spam, spam, spam, spam, spam, spam, spam, s…
A lot of these seem to be word or character frequencies, so I suspect that they might be sparse (a lot of zero values) and have skewed distributions. Investigating:
train |>
select(-spam) |>
pivot_longer(cols = everything(), names_to = "Feature", values_to = "Value") |>
ggplot(aes(x = Value)) +
geom_histogram(bins = 20, fill = "steelblue", color = "white") +
facet_wrap(~ Feature, scales = "free") +
theme_minimal()

Preprocessing
Below is the preprocessing recipe I used: synthetic minority oversampling (SMOTE) to balance the two classes (although the class imbalance is not too severe here), log-transforming to mitigate some of the right skew, rescaling to [0, 1], and dropping highly correlated features as well as those with near-zero variance:
spam_rec <-
recipe(spam ~ ., data = train) |>
themis::step_smote(spam, over_ratio = 1, neighbors = 5) |>
step_log(all_numeric_predictors(), offset = 1) |>
step_range(all_numeric_predictors(), min = 0, max = 1) |>
step_corr(all_numeric_predictors(), threshold = 0.9) |>
step_nzv(all_numeric_predictors())

The recipe is specified to be fitted (“prepped”) on the training data only, to avoid leaking information from the validation and test sets.
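To see concretely what the log and range steps do, here is the same transformation applied to a single toy vector in base R (a sketch with made-up values, not the actual data; `step_log(offset = 1)` computes `log(x + 1)`, and `step_range()` rescales using the min/max learned during prepping):

```r
# A toy right-skewed "frequency" feature, mimicking the word_freq_* columns:
x <- c(0, 0, 0.2, 0.5, 4.8, 12.1)

# step_log(offset = 1): compress the long right tail
x_log <- log(x + 1)

# step_range(min = 0, max = 1): rescale to [0, 1] via the observed min/max
x_scaled <- (x_log - min(x_log)) / (max(x_log) - min(x_log))

round(x_scaled, 3)
# 0.000 0.000 0.071 0.158 0.683 1.000
```

Because the recipe learns the min/max on the training data and reuses them when baking new data, validation/test values can fall slightly outside [0, 1]; that is expected and harmless here.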
EDA
Class distribution:
train |>
count(spam) |>
ggplot(aes(x = spam, y = n, color = spam, fill = spam)) +
geom_col(alpha = 0.7) +
theme_minimal() +
scale_color_brewer(palette = "Dark2", direction = -1) +
scale_fill_brewer(palette = "Dark2", direction = -1) +
labs(
title = "Class Distribution",
x = "",
y = "N. of obs"
) +
theme(legend.position = "none")

Feature correlations (unlabelled; labels would be hard to read with this many features anyway, this is just to check whether highly correlated features exist):
corrs <-
train |>
select(-spam) |>
cor() |>
as.data.frame() |>
rownames_to_column(var = "x1") |>
tibble() |>
pivot_longer(-x1, names_to = "x2", values_to = "val")
corrs |>
ggplot(aes(x = x1, y = x2, fill = val)) +
geom_tile() +
scale_fill_distiller(palette = "RdYlGn", direction = 1, limits = c(-1, 1)) +
theme(axis.text = element_blank(), axis.ticks = element_blank()) +
labs(
title = "Pairwise correlations",
subtitle = "All features",
fill = "Pearson's r",
x = "",
y = ""
)

PCA: checking whether we can find some patterns in a transformed representation (and also whether PCA might make sense as part of the preprocessing pipeline)1:
spam_rec |>
step_pca(all_numeric_predictors(), num_comp = 4) |>
prep() |>
bake(new_data = train) |>
ggplot(aes(x = .panel_x, y = .panel_y, color = spam, fill = spam)) +
geom_point(alpha = 0.25, size = 0.5) +
ggforce::geom_autodensity(alpha = .3) +
ggforce::facet_matrix(vars(-spam), layer.diag = 2) +
scale_color_brewer(palette = "Dark2", direction = -1) +
scale_fill_brewer(palette = "Dark2", direction = -1) +
theme_minimal() +
labs(title = "Principal Component Analysis", fill = "", color = "")

Neural Network Classifier
Preparing the data (separating features and labels, converting to matrix format). We also need to apply the fitted preprocessing pipeline to the data going into keras: prep() fits the recipe (on the training data, as specified in the recipe), and bake() applies the transformation, equivalent to .fit() and .transform() in sklearn pipelines:
keras_split <- function(set) {
df <-
set |>
mutate(spam = if_else(spam == "spam", 1, 0))
list(
X = df |> select(-spam) |> as.matrix() |> unname(),
y = df |> pull(spam) |> as.matrix()
)
}
prepped <- prep(spam_rec) # fit the recipe once, on the training data
keras_train <- prepped |> bake(new_data = train) |> keras_split()
keras_val <- prepped |> bake(new_data = val) |> keras_split()
keras_test <- prepped |> bake(new_data = test) |> keras_split()
X_train <- keras_train$X
y_train <- keras_train$y
X_val <- keras_val$X
y_val <- keras_val$y
X_test <- keras_test$X
y_test <- keras_test$y

Model:
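Before wiring matrices like these into keras, a quick shape sanity check can catch split mistakes early. A minimal sketch on a tiny stand-in matrix (made-up values, not the actual baked data), checking the same invariants the `keras_split()` output should satisfy:

```r
# Stand-in for one keras_split() result: 5 examples, 4 features, binary labels
X <- matrix(runif(20), nrow = 5)
y <- matrix(c(0, 1, 1, 0, 1))

stopifnot(
  nrow(X) == nrow(y),   # one label per example
  all(y %in% c(0, 1)),  # 0/1 labels for the sigmoid output & binary crossentropy
  is.numeric(X)         # keras expects a numeric matrix, not a data frame
)
```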
keras3::set_random_seed(42)
#^ the keras3 version (supposedly) sets a seed for the R session and the whole backend
mlp <- keras_model_sequential(
layers = list(
layer_dense(units = 128, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
layer_dropout(rate = 0.25),
layer_dense(units = 64, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
layer_dropout(rate = 0.25),
layer_dense(units = 32, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
layer_dropout(rate = 0.25),
layer_dense(units = 16, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
layer_dropout(rate = 0.25),
layer_dense(units = 1, activation = "sigmoid")
)
)

Compiling:
keras3::set_random_seed(42)
# chunks are apparently evaluated independently when rendering, so set the seed again
mlp |>
compile(
optimizer = optimizer_adam(learning_rate = 0.001),
loss = "binary_crossentropy",
metrics = list(
metric_binary_accuracy(name = "Accuracy"),
metric_precision(name = "Precision"),
metric_recall(name = "Recall"),
metric_f1_score(average = "micro", threshold = .5, name = "F1")
)
)

Training:
keras3::set_random_seed(42)
history <-
mlp |>
fit(
x = X_train,
y = y_train,
epochs = 250L,
batch_size = 32L,
validation_data = list(X_val, y_val),
callbacks = list(
# early stopping:
callback_early_stopping(
monitor = "val_loss",
patience = 5L,
restore_best_weights = TRUE
),
# schedule learning rate:
callback_reduce_lr_on_plateau(
monitor = "val_loss",
factor = 0.8,
patience = 3L,
min_lr = 0.00001
)
),
shuffle = FALSE
)

Epoch 1/250
87/87 - 6s - 68ms/step - Accuracy: 0.5640 - F1: 0.6408 - Precision: 0.4744 - Recall: 0.9871 - loss: 0.9687 - val_Accuracy: 0.6059 - val_F1: 0.0000e+00 - val_Precision: 0.0000e+00 - val_Recall: 0.0000e+00 - val_loss: 0.8075 - learning_rate: 0.0010
Epoch 2/250
87/87 - 0s - 3ms/step - Accuracy: 0.5143 - F1: 0.2796 - Precision: 0.3364 - Recall: 0.2392 - loss: 0.8138 - val_Accuracy: 0.6243 - val_F1: 0.0895 - val_Precision: 1.0000 - val_Recall: 0.0468 - val_loss: 0.7709 - learning_rate: 0.0010
Epoch 3/250
87/87 - 0s - 3ms/step - Accuracy: 0.6531 - F1: 0.5915 - Precision: 0.5518 - Recall: 0.6375 - loss: 0.7476 - val_Accuracy: 0.6059 - val_F1: 0.0000e+00 - val_Precision: 0.0000e+00 - val_Recall: 0.0000e+00 - val_loss: 0.7395 - learning_rate: 0.0010
Epoch 4/250
87/87 - 0s - 3ms/step - Accuracy: 0.6876 - F1: 0.6009 - Precision: 0.6048 - Recall: 0.5971 - loss: 0.7083 - val_Accuracy: 0.6384 - val_F1: 0.1570 - val_Precision: 0.9688 - val_Recall: 0.0854 - val_loss: 0.6782 - learning_rate: 0.0010
Epoch 5/250
87/87 - 0s - 3ms/step - Accuracy: 0.7746 - F1: 0.7027 - Precision: 0.7313 - Recall: 0.6762 - loss: 0.5927 - val_Accuracy: 0.6384 - val_F1: 0.1570 - val_Precision: 0.9688 - val_Recall: 0.0854 - val_loss: 0.8259 - learning_rate: 0.0010
Epoch 6/250
87/87 - 0s - 4ms/step - Accuracy: 0.7952 - F1: 0.7354 - Precision: 0.7490 - Recall: 0.7222 - loss: 0.5887 - val_Accuracy: 0.8187 - val_F1: 0.7116 - val_Precision: 0.9537 - val_Recall: 0.5675 - val_loss: 0.5009 - learning_rate: 0.0010
Epoch 7/250
87/87 - 0s - 2ms/step - Accuracy: 0.8590 - F1: 0.8152 - Precision: 0.8428 - Recall: 0.7893 - loss: 0.4420 - val_Accuracy: 0.8208 - val_F1: 0.7160 - val_Precision: 0.9541 - val_Recall: 0.5730 - val_loss: 0.5265 - learning_rate: 0.0010
Epoch 8/250
87/87 - 0s - 3ms/step - Accuracy: 0.8731 - F1: 0.8381 - Precision: 0.8428 - Recall: 0.8335 - loss: 0.4301 - val_Accuracy: 0.8469 - val_F1: 0.7700 - val_Precision: 0.9440 - val_Recall: 0.6501 - val_loss: 0.4732 - learning_rate: 0.0010
Epoch 9/250
87/87 - 0s - 3ms/step - Accuracy: 0.8847 - F1: 0.8528 - Precision: 0.8583 - Recall: 0.8473 - loss: 0.4070 - val_Accuracy: 0.8719 - val_F1: 0.8145 - val_Precision: 0.9487 - val_Recall: 0.7135 - val_loss: 0.4453 - learning_rate: 0.0010
Epoch 10/250
87/87 - 0s - 3ms/step - Accuracy: 0.8942 - F1: 0.8656 - Precision: 0.8664 - Recall: 0.8648 - loss: 0.3826 - val_Accuracy: 0.8762 - val_F1: 0.8219 - val_Precision: 0.9495 - val_Recall: 0.7245 - val_loss: 0.4344 - learning_rate: 0.0010
Epoch 11/250
87/87 - 0s - 4ms/step - Accuracy: 0.8949 - F1: 0.8666 - Precision: 0.8666 - Recall: 0.8666 - loss: 0.3866 - val_Accuracy: 0.8806 - val_F1: 0.8297 - val_Precision: 0.9470 - val_Recall: 0.7383 - val_loss: 0.4214 - learning_rate: 0.0010
Epoch 12/250
87/87 - 0s - 3ms/step - Accuracy: 0.8956 - F1: 0.8674 - Precision: 0.8682 - Recall: 0.8666 - loss: 0.3693 - val_Accuracy: 0.8871 - val_F1: 0.8405 - val_Precision: 0.9481 - val_Recall: 0.7548 - val_loss: 0.4229 - learning_rate: 0.0010
Epoch 13/250
87/87 - 0s - 3ms/step - Accuracy: 0.9007 - F1: 0.8742 - Precision: 0.8726 - Recall: 0.8758 - loss: 0.3653 - val_Accuracy: 0.8893 - val_F1: 0.8440 - val_Precision: 0.9485 - val_Recall: 0.7603 - val_loss: 0.4111 - learning_rate: 0.0010
Epoch 14/250
87/87 - 0s - 3ms/step - Accuracy: 0.9007 - F1: 0.8744 - Precision: 0.8712 - Recall: 0.8776 - loss: 0.3633 - val_Accuracy: 0.8893 - val_F1: 0.8445 - val_Precision: 0.9454 - val_Recall: 0.7631 - val_loss: 0.4048 - learning_rate: 0.0010
Epoch 15/250
87/87 - 0s - 4ms/step - Accuracy: 0.9069 - F1: 0.8814 - Precision: 0.8843 - Recall: 0.8786 - loss: 0.3575 - val_Accuracy: 0.8849 - val_F1: 0.8374 - val_Precision: 0.9446 - val_Recall: 0.7521 - val_loss: 0.4057 - learning_rate: 0.0010
Epoch 16/250
87/87 - 0s - 3ms/step - Accuracy: 0.9014 - F1: 0.8741 - Precision: 0.8798 - Recall: 0.8684 - loss: 0.3563 - val_Accuracy: 0.8871 - val_F1: 0.8410 - val_Precision: 0.9450 - val_Recall: 0.7576 - val_loss: 0.3966 - learning_rate: 0.0010
Epoch 17/250
87/87 - 0s - 3ms/step - Accuracy: 0.9032 - F1: 0.8769 - Precision: 0.8789 - Recall: 0.8749 - loss: 0.3451 - val_Accuracy: 0.8882 - val_F1: 0.8427 - val_Precision: 0.9452 - val_Recall: 0.7603 - val_loss: 0.3924 - learning_rate: 0.0010
Epoch 18/250
87/87 - 0s - 3ms/step - Accuracy: 0.9087 - F1: 0.8840 - Precision: 0.8848 - Recall: 0.8832 - loss: 0.3285 - val_Accuracy: 0.8925 - val_F1: 0.8502 - val_Precision: 0.9430 - val_Recall: 0.7741 - val_loss: 0.3971 - learning_rate: 0.0010
Epoch 19/250
87/87 - 0s - 3ms/step - Accuracy: 0.9108 - F1: 0.8867 - Precision: 0.8876 - Recall: 0.8859 - loss: 0.3363 - val_Accuracy: 0.8893 - val_F1: 0.8445 - val_Precision: 0.9454 - val_Recall: 0.7631 - val_loss: 0.3936 - learning_rate: 0.0010
Epoch 20/250
87/87 - 0s - 3ms/step - Accuracy: 0.9134 - F1: 0.8903 - Precision: 0.8883 - Recall: 0.8924 - loss: 0.3328 - val_Accuracy: 0.8817 - val_F1: 0.8315 - val_Precision: 0.9472 - val_Recall: 0.7410 - val_loss: 0.4040 - learning_rate: 0.0010
Epoch 21/250
87/87 - 0s - 3ms/step - Accuracy: 0.9047 - F1: 0.8783 - Precision: 0.8836 - Recall: 0.8730 - loss: 0.3502 - val_Accuracy: 0.9088 - val_F1: 0.8765 - val_Precision: 0.9401 - val_Recall: 0.8209 - val_loss: 0.3405 - learning_rate: 8.0000e-04
Epoch 22/250
87/87 - 0s - 3ms/step - Accuracy: 0.9163 - F1: 0.8933 - Precision: 0.8970 - Recall: 0.8896 - loss: 0.3249 - val_Accuracy: 0.9077 - val_F1: 0.8748 - val_Precision: 0.9399 - val_Recall: 0.8182 - val_loss: 0.3401 - learning_rate: 8.0000e-04
Epoch 23/250
87/87 - 0s - 3ms/step - Accuracy: 0.9134 - F1: 0.8899 - Precision: 0.8911 - Recall: 0.8887 - loss: 0.3303 - val_Accuracy: 0.9034 - val_F1: 0.8678 - val_Precision: 0.9419 - val_Recall: 0.8044 - val_loss: 0.3461 - learning_rate: 8.0000e-04
Epoch 24/250
87/87 - 0s - 3ms/step - Accuracy: 0.9145 - F1: 0.8913 - Precision: 0.8922 - Recall: 0.8905 - loss: 0.3277 - val_Accuracy: 0.8990 - val_F1: 0.8610 - val_Precision: 0.9412 - val_Recall: 0.7934 - val_loss: 0.3511 - learning_rate: 8.0000e-04
Epoch 25/250
87/87 - 0s - 3ms/step - Accuracy: 0.9083 - F1: 0.8830 - Precision: 0.8875 - Recall: 0.8786 - loss: 0.3307 - val_Accuracy: 0.8979 - val_F1: 0.8593 - val_Precision: 0.9410 - val_Recall: 0.7906 - val_loss: 0.3520 - learning_rate: 8.0000e-04
Epoch 26/250
87/87 - 0s - 5ms/step - Accuracy: 0.9065 - F1: 0.8803 - Precision: 0.8877 - Recall: 0.8730 - loss: 0.3384 - val_Accuracy: 0.9110 - val_F1: 0.8798 - val_Precision: 0.9404 - val_Recall: 0.8264 - val_loss: 0.3156 - learning_rate: 6.4000e-04
Epoch 27/250
87/87 - 0s - 3ms/step - Accuracy: 0.9184 - F1: 0.8961 - Precision: 0.8998 - Recall: 0.8924 - loss: 0.3143 - val_Accuracy: 0.9110 - val_F1: 0.8798 - val_Precision: 0.9404 - val_Recall: 0.8264 - val_loss: 0.3128 - learning_rate: 6.4000e-04
Epoch 28/250
87/87 - 0s - 3ms/step - Accuracy: 0.9141 - F1: 0.8912 - Precision: 0.8892 - Recall: 0.8933 - loss: 0.3151 - val_Accuracy: 0.9110 - val_F1: 0.8798 - val_Precision: 0.9404 - val_Recall: 0.8264 - val_loss: 0.3167 - learning_rate: 6.4000e-04
Epoch 29/250
87/87 - 0s - 3ms/step - Accuracy: 0.9181 - F1: 0.8961 - Precision: 0.8953 - Recall: 0.8970 - loss: 0.3152 - val_Accuracy: 0.9066 - val_F1: 0.8732 - val_Precision: 0.9397 - val_Recall: 0.8154 - val_loss: 0.3187 - learning_rate: 6.4000e-04
Epoch 30/250
87/87 - 0s - 4ms/step - Accuracy: 0.9155 - F1: 0.8925 - Precision: 0.8954 - Recall: 0.8896 - loss: 0.3143 - val_Accuracy: 0.9045 - val_F1: 0.8698 - val_Precision: 0.9393 - val_Recall: 0.8099 - val_loss: 0.3198 - learning_rate: 6.4000e-04
Epoch 31/250
87/87 - 0s - 3ms/step - Accuracy: 0.9130 - F1: 0.8892 - Precision: 0.8925 - Recall: 0.8859 - loss: 0.3078 - val_Accuracy: 0.9153 - val_F1: 0.8863 - val_Precision: 0.9412 - val_Recall: 0.8375 - val_loss: 0.3051 - learning_rate: 5.1200e-04
Epoch 32/250
87/87 - 0s - 3ms/step - Accuracy: 0.9137 - F1: 0.8901 - Precision: 0.8934 - Recall: 0.8868 - loss: 0.3040 - val_Accuracy: 0.9142 - val_F1: 0.8847 - val_Precision: 0.9410 - val_Recall: 0.8347 - val_loss: 0.3027 - learning_rate: 5.1200e-04
Epoch 33/250
87/87 - 0s - 3ms/step - Accuracy: 0.9203 - F1: 0.8989 - Precision: 0.8981 - Recall: 0.8997 - loss: 0.3037 - val_Accuracy: 0.9121 - val_F1: 0.8814 - val_Precision: 0.9406 - val_Recall: 0.8292 - val_loss: 0.3027 - learning_rate: 5.1200e-04
Epoch 34/250
87/87 - 0s - 3ms/step - Accuracy: 0.9174 - F1: 0.8947 - Precision: 0.8981 - Recall: 0.8914 - loss: 0.3076 - val_Accuracy: 0.9099 - val_F1: 0.8781 - val_Precision: 0.9403 - val_Recall: 0.8237 - val_loss: 0.3059 - learning_rate: 5.1200e-04
Epoch 35/250
87/87 - 0s - 3ms/step - Accuracy: 0.9152 - F1: 0.8925 - Precision: 0.8916 - Recall: 0.8933 - loss: 0.3018 - val_Accuracy: 0.9110 - val_F1: 0.8798 - val_Precision: 0.9404 - val_Recall: 0.8264 - val_loss: 0.3028 - learning_rate: 5.1200e-04
Epoch 36/250
87/87 - 0s - 3ms/step - Accuracy: 0.9152 - F1: 0.8919 - Precision: 0.8960 - Recall: 0.8878 - loss: 0.3067 - val_Accuracy: 0.9121 - val_F1: 0.8821 - val_Precision: 0.9352 - val_Recall: 0.8347 - val_loss: 0.2926 - learning_rate: 4.0960e-04
Epoch 37/250
87/87 - 0s - 3ms/step - Accuracy: 0.9213 - F1: 0.8996 - Precision: 0.9050 - Recall: 0.8942 - loss: 0.2975 - val_Accuracy: 0.9131 - val_F1: 0.8834 - val_Precision: 0.9381 - val_Recall: 0.8347 - val_loss: 0.2923 - learning_rate: 4.0960e-04
Epoch 38/250
87/87 - 0s - 4ms/step - Accuracy: 0.9206 - F1: 0.8988 - Precision: 0.9033 - Recall: 0.8942 - loss: 0.3011 - val_Accuracy: 0.9121 - val_F1: 0.8821 - val_Precision: 0.9352 - val_Recall: 0.8347 - val_loss: 0.2894 - learning_rate: 4.0960e-04
Epoch 39/250
87/87 - 0s - 4ms/step - Accuracy: 0.9242 - F1: 0.9032 - Precision: 0.9095 - Recall: 0.8970 - loss: 0.2868 - val_Accuracy: 0.9110 - val_F1: 0.8801 - val_Precision: 0.9377 - val_Recall: 0.8292 - val_loss: 0.2952 - learning_rate: 4.0960e-04
Epoch 40/250
87/87 - 0s - 4ms/step - Accuracy: 0.9213 - F1: 0.8998 - Precision: 0.9035 - Recall: 0.8960 - loss: 0.2946 - val_Accuracy: 0.9121 - val_F1: 0.8814 - val_Precision: 0.9406 - val_Recall: 0.8292 - val_loss: 0.2947 - learning_rate: 4.0960e-04
Epoch 41/250
87/87 - 0s - 3ms/step - Accuracy: 0.9232 - F1: 0.9019 - Precision: 0.9070 - Recall: 0.8970 - loss: 0.2952 - val_Accuracy: 0.9121 - val_F1: 0.8814 - val_Precision: 0.9406 - val_Recall: 0.8292 - val_loss: 0.2944 - learning_rate: 4.0960e-04
Epoch 42/250
87/87 - 0s - 3ms/step - Accuracy: 0.9184 - F1: 0.8953 - Precision: 0.9058 - Recall: 0.8850 - loss: 0.2921 - val_Accuracy: 0.9142 - val_F1: 0.8853 - val_Precision: 0.9356 - val_Recall: 0.8402 - val_loss: 0.2852 - learning_rate: 3.2768e-04
Epoch 43/250
87/87 - 0s - 4ms/step - Accuracy: 0.9261 - F1: 0.9055 - Precision: 0.9122 - Recall: 0.8988 - loss: 0.2848 - val_Accuracy: 0.9142 - val_F1: 0.8853 - val_Precision: 0.9356 - val_Recall: 0.8402 - val_loss: 0.2850 - learning_rate: 3.2768e-04
Epoch 44/250
87/87 - 0s - 4ms/step - Accuracy: 0.9253 - F1: 0.9049 - Precision: 0.9082 - Recall: 0.9016 - loss: 0.2879 - val_Accuracy: 0.9164 - val_F1: 0.8886 - val_Precision: 0.9360 - val_Recall: 0.8457 - val_loss: 0.2811 - learning_rate: 3.2768e-04
Epoch 45/250
87/87 - 0s - 4ms/step - Accuracy: 0.9217 - F1: 0.9005 - Precision: 0.9021 - Recall: 0.8988 - loss: 0.2883 - val_Accuracy: 0.9175 - val_F1: 0.8902 - val_Precision: 0.9362 - val_Recall: 0.8485 - val_loss: 0.2809 - learning_rate: 3.2768e-04
Epoch 46/250
87/87 - 0s - 3ms/step - Accuracy: 0.9261 - F1: 0.9057 - Precision: 0.9099 - Recall: 0.9016 - loss: 0.2809 - val_Accuracy: 0.9175 - val_F1: 0.8899 - val_Precision: 0.9388 - val_Recall: 0.8457 - val_loss: 0.2837 - learning_rate: 3.2768e-04
Epoch 47/250
87/87 - 0s - 3ms/step - Accuracy: 0.9250 - F1: 0.9044 - Precision: 0.9082 - Recall: 0.9006 - loss: 0.2829 - val_Accuracy: 0.9131 - val_F1: 0.8834 - val_Precision: 0.9381 - val_Recall: 0.8347 - val_loss: 0.2874 - learning_rate: 3.2768e-04
Epoch 48/250
87/87 - 0s - 3ms/step - Accuracy: 0.9224 - F1: 0.9011 - Precision: 0.9053 - Recall: 0.8970 - loss: 0.2800 - val_Accuracy: 0.9142 - val_F1: 0.8850 - val_Precision: 0.9383 - val_Recall: 0.8375 - val_loss: 0.2862 - learning_rate: 3.2768e-04
Epoch 49/250
87/87 - 0s - 3ms/step - Accuracy: 0.9239 - F1: 0.9027 - Precision: 0.9094 - Recall: 0.8960 - loss: 0.2793 - val_Accuracy: 0.9197 - val_F1: 0.8934 - val_Precision: 0.9366 - val_Recall: 0.8540 - val_loss: 0.2755 - learning_rate: 2.6214e-04
Epoch 50/250
87/87 - 0s - 3ms/step - Accuracy: 0.9286 - F1: 0.9093 - Precision: 0.9105 - Recall: 0.9080 - loss: 0.2806 - val_Accuracy: 0.9186 - val_F1: 0.8918 - val_Precision: 0.9364 - val_Recall: 0.8512 - val_loss: 0.2772 - learning_rate: 2.6214e-04
Epoch 51/250
87/87 - 0s - 4ms/step - Accuracy: 0.9271 - F1: 0.9071 - Precision: 0.9117 - Recall: 0.9025 - loss: 0.2737 - val_Accuracy: 0.9175 - val_F1: 0.8902 - val_Precision: 0.9362 - val_Recall: 0.8485 - val_loss: 0.2773 - learning_rate: 2.6214e-04
Epoch 52/250
87/87 - 0s - 4ms/step - Accuracy: 0.9253 - F1: 0.9047 - Precision: 0.9098 - Recall: 0.8997 - loss: 0.2809 - val_Accuracy: 0.9175 - val_F1: 0.8902 - val_Precision: 0.9362 - val_Recall: 0.8485 - val_loss: 0.2770 - learning_rate: 2.6214e-04
Epoch 53/250
87/87 - 0s - 3ms/step - Accuracy: 0.9290 - F1: 0.9088 - Precision: 0.9191 - Recall: 0.8988 - loss: 0.2768 - val_Accuracy: 0.9186 - val_F1: 0.8918 - val_Precision: 0.9364 - val_Recall: 0.8512 - val_loss: 0.2726 - learning_rate: 2.0972e-04
Epoch 54/250
87/87 - 0s - 4ms/step - Accuracy: 0.9297 - F1: 0.9104 - Precision: 0.9138 - Recall: 0.9071 - loss: 0.2805 - val_Accuracy: 0.9186 - val_F1: 0.8918 - val_Precision: 0.9364 - val_Recall: 0.8512 - val_loss: 0.2733 - learning_rate: 2.0972e-04
Epoch 55/250
87/87 - 0s - 4ms/step - Accuracy: 0.9279 - F1: 0.9078 - Precision: 0.9142 - Recall: 0.9016 - loss: 0.2693 - val_Accuracy: 0.9197 - val_F1: 0.8934 - val_Precision: 0.9366 - val_Recall: 0.8540 - val_loss: 0.2723 - learning_rate: 2.0972e-04
Epoch 56/250
87/87 - 0s - 3ms/step - Accuracy: 0.9257 - F1: 0.9050 - Precision: 0.9121 - Recall: 0.8979 - loss: 0.2782 - val_Accuracy: 0.9207 - val_F1: 0.8950 - val_Precision: 0.9367 - val_Recall: 0.8567 - val_loss: 0.2711 - learning_rate: 2.0972e-04
Epoch 57/250
87/87 - 0s - 3ms/step - Accuracy: 0.9286 - F1: 0.9092 - Precision: 0.9113 - Recall: 0.9071 - loss: 0.2758 - val_Accuracy: 0.9197 - val_F1: 0.8934 - val_Precision: 0.9366 - val_Recall: 0.8540 - val_loss: 0.2715 - learning_rate: 2.0972e-04
Epoch 58/250
87/87 - 0s - 4ms/step - Accuracy: 0.9300 - F1: 0.9107 - Precision: 0.9162 - Recall: 0.9052 - loss: 0.2739 - val_Accuracy: 0.9229 - val_F1: 0.8981 - val_Precision: 0.9371 - val_Recall: 0.8623 - val_loss: 0.2709 - learning_rate: 2.0972e-04
Epoch 59/250
87/87 - 0s - 4ms/step - Accuracy: 0.9297 - F1: 0.9099 - Precision: 0.9185 - Recall: 0.9016 - loss: 0.2721 - val_Accuracy: 0.9207 - val_F1: 0.8950 - val_Precision: 0.9367 - val_Recall: 0.8567 - val_loss: 0.2717 - learning_rate: 2.0972e-04
Epoch 60/250
87/87 - 0s - 3ms/step - Accuracy: 0.9322 - F1: 0.9139 - Precision: 0.9151 - Recall: 0.9126 - loss: 0.2616 - val_Accuracy: 0.9197 - val_F1: 0.8934 - val_Precision: 0.9366 - val_Recall: 0.8540 - val_loss: 0.2722 - learning_rate: 2.0972e-04
Epoch 61/250
87/87 - 0s - 3ms/step - Accuracy: 0.9261 - F1: 0.9056 - Precision: 0.9107 - Recall: 0.9006 - loss: 0.2713 - val_Accuracy: 0.9207 - val_F1: 0.8950 - val_Precision: 0.9367 - val_Recall: 0.8567 - val_loss: 0.2720 - learning_rate: 2.0972e-04
Epoch 62/250
87/87 - 0s - 3ms/step - Accuracy: 0.9297 - F1: 0.9100 - Precision: 0.9177 - Recall: 0.9025 - loss: 0.2703 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2700 - learning_rate: 1.6777e-04
Epoch 63/250
87/87 - 0s - 2ms/step - Accuracy: 0.9308 - F1: 0.9116 - Precision: 0.9171 - Recall: 0.9062 - loss: 0.2676 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2697 - learning_rate: 1.6777e-04
Epoch 64/250
87/87 - 0s - 3ms/step - Accuracy: 0.9293 - F1: 0.9098 - Precision: 0.9145 - Recall: 0.9052 - loss: 0.2707 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2682 - learning_rate: 1.6777e-04
Epoch 65/250
87/87 - 0s - 4ms/step - Accuracy: 0.9293 - F1: 0.9098 - Precision: 0.9145 - Recall: 0.9052 - loss: 0.2754 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2675 - learning_rate: 1.6777e-04
Epoch 66/250
87/87 - 0s - 3ms/step - Accuracy: 0.9293 - F1: 0.9099 - Precision: 0.9137 - Recall: 0.9062 - loss: 0.2634 - val_Accuracy: 0.9251 - val_F1: 0.9013 - val_Precision: 0.9375 - val_Recall: 0.8678 - val_loss: 0.2666 - learning_rate: 1.6777e-04
Epoch 67/250
87/87 - 0s - 4ms/step - Accuracy: 0.9304 - F1: 0.9114 - Precision: 0.9147 - Recall: 0.9080 - loss: 0.2714 - val_Accuracy: 0.9251 - val_F1: 0.9013 - val_Precision: 0.9375 - val_Recall: 0.8678 - val_loss: 0.2669 - learning_rate: 1.6777e-04
Epoch 68/250
87/87 - 0s - 3ms/step - Accuracy: 0.9311 - F1: 0.9119 - Precision: 0.9196 - Recall: 0.9043 - loss: 0.2673 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2675 - learning_rate: 1.6777e-04
Epoch 69/250
87/87 - 0s - 4ms/step - Accuracy: 0.9315 - F1: 0.9124 - Precision: 0.9196 - Recall: 0.9052 - loss: 0.2656 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2677 - learning_rate: 1.6777e-04
Epoch 70/250
87/87 - 0s - 3ms/step - Accuracy: 0.9326 - F1: 0.9134 - Precision: 0.9246 - Recall: 0.9025 - loss: 0.2647 - val_Accuracy: 0.9251 - val_F1: 0.9013 - val_Precision: 0.9375 - val_Recall: 0.8678 - val_loss: 0.2665 - learning_rate: 1.3422e-04
Epoch 71/250
87/87 - 0s - 3ms/step - Accuracy: 0.9286 - F1: 0.9086 - Precision: 0.9167 - Recall: 0.9006 - loss: 0.2701 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2653 - learning_rate: 1.3422e-04
Epoch 72/250
87/87 - 0s - 3ms/step - Accuracy: 0.9275 - F1: 0.9072 - Precision: 0.9149 - Recall: 0.8997 - loss: 0.2623 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2652 - learning_rate: 1.3422e-04
Epoch 73/250
87/87 - 0s - 3ms/step - Accuracy: 0.9297 - F1: 0.9102 - Precision: 0.9161 - Recall: 0.9043 - loss: 0.2586 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2659 - learning_rate: 1.3422e-04
Epoch 74/250
87/87 - 0s - 3ms/step - Accuracy: 0.9337 - F1: 0.9153 - Precision: 0.9209 - Recall: 0.9098 - loss: 0.2656 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2659 - learning_rate: 1.3422e-04
Epoch 75/250
87/87 - 0s - 3ms/step - Accuracy: 0.9333 - F1: 0.9147 - Precision: 0.9224 - Recall: 0.9071 - loss: 0.2622 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2650 - learning_rate: 1.3422e-04
Epoch 76/250
87/87 - 0s - 3ms/step - Accuracy: 0.9344 - F1: 0.9161 - Precision: 0.9234 - Recall: 0.9089 - loss: 0.2581 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2646 - learning_rate: 1.3422e-04
Epoch 77/250
87/87 - 0s - 4ms/step - Accuracy: 0.9308 - F1: 0.9115 - Precision: 0.9187 - Recall: 0.9043 - loss: 0.2654 - val_Accuracy: 0.9273 - val_F1: 0.9044 - val_Precision: 0.9379 - val_Recall: 0.8733 - val_loss: 0.2631 - learning_rate: 1.3422e-04
Epoch 78/250
87/87 - 0s - 3ms/step - Accuracy: 0.9304 - F1: 0.9114 - Precision: 0.9140 - Recall: 0.9089 - loss: 0.2632 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2638 - learning_rate: 1.3422e-04
Epoch 79/250
87/87 - 0s - 4ms/step - Accuracy: 0.9308 - F1: 0.9117 - Precision: 0.9164 - Recall: 0.9071 - loss: 0.2668 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2648 - learning_rate: 1.3422e-04
Epoch 80/250
87/87 - 0s - 4ms/step - Accuracy: 0.9322 - F1: 0.9134 - Precision: 0.9198 - Recall: 0.9071 - loss: 0.2639 - val_Accuracy: 0.9240 - val_F1: 0.8997 - val_Precision: 0.9373 - val_Recall: 0.8650 - val_loss: 0.2645 - learning_rate: 1.3422e-04
Epoch 81/250
87/87 - 0s - 3ms/step - Accuracy: 0.9322 - F1: 0.9132 - Precision: 0.9213 - Recall: 0.9052 - loss: 0.2631 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2616 - learning_rate: 1.0737e-04
Epoch 82/250
87/87 - 0s - 3ms/step - Accuracy: 0.9315 - F1: 0.9123 - Precision: 0.9204 - Recall: 0.9043 - loss: 0.2572 - val_Accuracy: 0.9273 - val_F1: 0.9044 - val_Precision: 0.9379 - val_Recall: 0.8733 - val_loss: 0.2621 - learning_rate: 1.0737e-04
Epoch 83/250
87/87 - 0s - 3ms/step - Accuracy: 0.9304 - F1: 0.9110 - Precision: 0.9178 - Recall: 0.9043 - loss: 0.2602 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2619 - learning_rate: 1.0737e-04
Epoch 84/250
87/87 - 0s - 3ms/step - Accuracy: 0.9319 - F1: 0.9129 - Precision: 0.9197 - Recall: 0.9062 - loss: 0.2598 - val_Accuracy: 0.9251 - val_F1: 0.9013 - val_Precision: 0.9375 - val_Recall: 0.8678 - val_loss: 0.2636 - learning_rate: 1.0737e-04
Epoch 85/250
87/87 - 0s - 5ms/step - Accuracy: 0.9322 - F1: 0.9131 - Precision: 0.9221 - Recall: 0.9043 - loss: 0.2496 - val_Accuracy: 0.9262 - val_F1: 0.9029 - val_Precision: 0.9377 - val_Recall: 0.8705 - val_loss: 0.2617 - learning_rate: 8.5899e-05
Epoch 86/250
87/87 - 0s - 3ms/step - Accuracy: 0.9319 - F1: 0.9130 - Precision: 0.9189 - Recall: 0.9071 - loss: 0.2602 - val_Accuracy: 0.9251 - val_F1: 0.9016 - val_Precision: 0.9349 - val_Recall: 0.8705 - val_loss: 0.2614 - learning_rate: 8.5899e-05
Epoch 87/250
87/87 - 0s - 3ms/step - Accuracy: 0.9348 - F1: 0.9169 - Precision: 0.9203 - Recall: 0.9135 - loss: 0.2549 - val_Accuracy: 0.9251 - val_F1: 0.9016 - val_Precision: 0.9349 - val_Recall: 0.8705 - val_loss: 0.2614 - learning_rate: 8.5899e-05
Epoch 88/250
87/87 - 0s - 3ms/step - Accuracy: 0.9344 - F1: 0.9162 - Precision: 0.9226 - Recall: 0.9098 - loss: 0.2572 - val_Accuracy: 0.9251 - val_F1: 0.9016 - val_Precision: 0.9349 - val_Recall: 0.8705 - val_loss: 0.2608 - learning_rate: 8.5899e-05
Epoch 89/250
87/87 - 0s - 3ms/step - Accuracy: 0.9344 - F1: 0.9163 - Precision: 0.9210 - Recall: 0.9117 - loss: 0.2566 - val_Accuracy: 0.9251 - val_F1: 0.9016 - val_Precision: 0.9349 - val_Recall: 0.8705 - val_loss: 0.2614 - learning_rate: 8.5899e-05
Epoch 90/250
87/87 - 0s - 3ms/step - Accuracy: 0.9333 - F1: 0.9147 - Precision: 0.9224 - Recall: 0.9071 - loss: 0.2623 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2597 - learning_rate: 8.5899e-05
Epoch 91/250
87/87 - 0s - 3ms/step - Accuracy: 0.9333 - F1: 0.9143 - Precision: 0.9255 - Recall: 0.9034 - loss: 0.2538 - val_Accuracy: 0.9251 - val_F1: 0.9016 - val_Precision: 0.9349 - val_Recall: 0.8705 - val_loss: 0.2608 - learning_rate: 8.5899e-05
Epoch 92/250
87/87 - 0s - 3ms/step - Accuracy: 0.9355 - F1: 0.9174 - Precision: 0.9260 - Recall: 0.9089 - loss: 0.2640 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2603 - learning_rate: 8.5899e-05
Epoch 93/250
87/87 - 0s - 3ms/step - Accuracy: 0.9355 - F1: 0.9172 - Precision: 0.9276 - Recall: 0.9071 - loss: 0.2526 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2607 - learning_rate: 8.5899e-05
Epoch 94/250
87/87 - 0s - 2ms/step - Accuracy: 0.9333 - F1: 0.9146 - Precision: 0.9231 - Recall: 0.9062 - loss: 0.2509 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2600 - learning_rate: 6.8719e-05
Epoch 95/250
87/87 - 0s - 3ms/step - Accuracy: 0.9337 - F1: 0.9153 - Precision: 0.9209 - Recall: 0.9098 - loss: 0.2620 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2591 - learning_rate: 6.8719e-05
Epoch 96/250
87/87 - 0s - 3ms/step - Accuracy: 0.9377 - F1: 0.9207 - Precision: 0.9232 - Recall: 0.9181 - loss: 0.2502 - val_Accuracy: 0.9262 - val_F1: 0.9031 - val_Precision: 0.9351 - val_Recall: 0.8733 - val_loss: 0.2597 - learning_rate: 6.8719e-05
Epoch 97/250
87/87 - 0s - 3ms/step - Accuracy: 0.9322 - F1: 0.9135 - Precision: 0.9190 - Recall: 0.9080 - loss: 0.2551 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2589 - learning_rate: 6.8719e-05
Epoch 98/250
87/87 - 0s - 3ms/step - Accuracy: 0.9333 - F1: 0.9147 - Precision: 0.9216 - Recall: 0.9080 - loss: 0.2560 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2584 - learning_rate: 6.8719e-05
Epoch 99/250
87/87 - 0s - 3ms/step - Accuracy: 0.9344 - F1: 0.9163 - Precision: 0.9210 - Recall: 0.9117 - loss: 0.2533 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2587 - learning_rate: 6.8719e-05
Epoch 100/250
87/87 - 0s - 3ms/step - Accuracy: 0.9355 - F1: 0.9180 - Precision: 0.9189 - Recall: 0.9172 - loss: 0.2420 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2600 - learning_rate: 6.8719e-05
Epoch 101/250
87/87 - 0s - 4ms/step - Accuracy: 0.9351 - F1: 0.9168 - Precision: 0.9267 - Recall: 0.9071 - loss: 0.2499 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2602 - learning_rate: 6.8719e-05
Epoch 102/250
87/87 - 0s - 3ms/step - Accuracy: 0.9358 - F1: 0.9178 - Precision: 0.9268 - Recall: 0.9089 - loss: 0.2441 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2592 - learning_rate: 5.4976e-05
Epoch 103/250
87/87 - 0s - 2ms/step - Accuracy: 0.9369 - F1: 0.9193 - Precision: 0.9270 - Recall: 0.9117 - loss: 0.2456 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2584 - learning_rate: 5.4976e-05
Epoch 104/250
87/87 - 0s - 3ms/step - Accuracy: 0.9387 - F1: 0.9216 - Precision: 0.9298 - Recall: 0.9135 - loss: 0.2505 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2588 - learning_rate: 5.4976e-05
Epoch 105/250
87/87 - 0s - 3ms/step - Accuracy: 0.9366 - F1: 0.9190 - Precision: 0.9246 - Recall: 0.9135 - loss: 0.2510 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2588 - learning_rate: 4.3980e-05
Epoch 106/250
87/87 - 0s - 3ms/step - Accuracy: 0.9366 - F1: 0.9189 - Precision: 0.9262 - Recall: 0.9117 - loss: 0.2403 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2590 - learning_rate: 4.3980e-05
Epoch 107/250
87/87 - 0s - 3ms/step - Accuracy: 0.9358 - F1: 0.9180 - Precision: 0.9244 - Recall: 0.9117 - loss: 0.2612 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2582 - learning_rate: 4.3980e-05
Epoch 108/250
87/87 - 0s - 3ms/step - Accuracy: 0.9348 - F1: 0.9164 - Precision: 0.9250 - Recall: 0.9080 - loss: 0.2517 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2579 - learning_rate: 4.3980e-05
Epoch 109/250
87/87 - 0s - 3ms/step - Accuracy: 0.9337 - F1: 0.9152 - Precision: 0.9216 - Recall: 0.9089 - loss: 0.2565 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2579 - learning_rate: 4.3980e-05
Epoch 110/250
87/87 - 0s - 4ms/step - Accuracy: 0.9337 - F1: 0.9146 - Precision: 0.9280 - Recall: 0.9016 - loss: 0.2547 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2578 - learning_rate: 4.3980e-05
Epoch 111/250
87/87 - 0s - 3ms/step - Accuracy: 0.9402 - F1: 0.9236 - Precision: 0.9292 - Recall: 0.9181 - loss: 0.2502 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2581 - learning_rate: 4.3980e-05
Epoch 112/250
87/87 - 0s - 4ms/step - Accuracy: 0.9358 - F1: 0.9176 - Precision: 0.9292 - Recall: 0.9062 - loss: 0.2455 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2580 - learning_rate: 4.3980e-05
Epoch 113/250
87/87 - 0s - 3ms/step - Accuracy: 0.9351 - F1: 0.9166 - Precision: 0.9283 - Recall: 0.9052 - loss: 0.2470 - val_Accuracy: 0.9273 - val_F1: 0.9047 - val_Precision: 0.9353 - val_Recall: 0.8760 - val_loss: 0.2579 - learning_rate: 4.3980e-05
Epoch 114/250
87/87 - 0s - 3ms/step - Accuracy: 0.9391 - F1: 0.9222 - Precision: 0.9290 - Recall: 0.9154 - loss: 0.2426 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2577 - learning_rate: 3.5184e-05
Epoch 115/250
87/87 - 0s - 3ms/step - Accuracy: 0.9395 - F1: 0.9227 - Precision: 0.9283 - Recall: 0.9172 - loss: 0.2404 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2580 - learning_rate: 3.5184e-05
Epoch 116/250
87/87 - 0s - 3ms/step - Accuracy: 0.9369 - F1: 0.9193 - Precision: 0.9270 - Recall: 0.9117 - loss: 0.2488 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2578 - learning_rate: 3.5184e-05
Epoch 117/250
87/87 - 0s - 3ms/step - Accuracy: 0.9344 - F1: 0.9162 - Precision: 0.9218 - Recall: 0.9108 - loss: 0.2481 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2578 - learning_rate: 2.8147e-05
Epoch 118/250
87/87 - 0s - 3ms/step - Accuracy: 0.9351 - F1: 0.9168 - Precision: 0.9267 - Recall: 0.9071 - loss: 0.2486 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2571 - learning_rate: 2.8147e-05
Epoch 119/250
87/87 - 0s - 3ms/step - Accuracy: 0.9369 - F1: 0.9193 - Precision: 0.9270 - Recall: 0.9117 - loss: 0.2422 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2572 - learning_rate: 2.8147e-05
Epoch 120/250
87/87 - 0s - 4ms/step - Accuracy: 0.9387 - F1: 0.9214 - Precision: 0.9322 - Recall: 0.9108 - loss: 0.2464 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2573 - learning_rate: 2.8147e-05
Epoch 121/250
87/87 - 0s - 4ms/step - Accuracy: 0.9358 - F1: 0.9179 - Precision: 0.9252 - Recall: 0.9108 - loss: 0.2450 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2573 - learning_rate: 2.8147e-05
Epoch 122/250
87/87 - 0s - 4ms/step - Accuracy: 0.9369 - F1: 0.9192 - Precision: 0.9278 - Recall: 0.9108 - loss: 0.2499 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2574 - learning_rate: 2.2518e-05
Epoch 123/250
87/87 - 0s - 3ms/step - Accuracy: 0.9373 - F1: 0.9197 - Precision: 0.9279 - Recall: 0.9117 - loss: 0.2463 - val_Accuracy: 0.9273 - val_F1: 0.9050 - val_Precision: 0.9327 - val_Recall: 0.8788 - val_loss: 0.2574 - learning_rate: 2.2518e-05
Details on the trained model:
summary(mlp)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense) │ (None, 128) │ 3,584 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout (Dropout) │ (None, 128) │ 0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense) │ (None, 64) │ 8,256 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_1 (Dropout) │ (None, 64) │ 0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_2 (Dense) │ (None, 32) │ 2,080 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_2 (Dropout) │ (None, 32) │ 0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_3 (Dense) │ (None, 16) │ 528 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_3 (Dropout) │ (None, 16) │ 0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_4 (Dense) │ (None, 1) │ 17 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
Total params: 43,397 (169.52 KB)
Trainable params: 14,465 (56.50 KB)
Non-trainable params: 0 (0.00 B)
Optimizer params: 28,932 (113.02 KB)
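For reference, the architecture in the summary could be reconstructed roughly like this with keras3. The layer sizes follow from the summary; the input dimension of 27 is inferred from the first layer's parameter count (3,584 = 128 × (27 + 1), so the preprocessing recipe apparently reduced the 57 raw predictors), while the activations and dropout rates are assumptions on my part:

```r
# Sketch of the architecture shown above, NOT the original training code.
# Layer sizes match the summary; activations and dropout rates are assumed.
mlp_sketch <- keras_model_sequential(input_shape = 27) |>
  layer_dense(units = 128, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 64, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 32, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 16, activation = "relu") |>
  layer_dropout(rate = 0.3) |>
  # single sigmoid unit for the binary spam/no-spam decision:
  layer_dense(units = 1, activation = "sigmoid")
```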
Visually:
mlp |> plot(show_shapes = TRUE, show_trainable = TRUE)
Looking at the training progression:
training_prog <-
history |>
as.data.frame() |>
tibble() |>
pivot_wider(values_from = "value", names_from = "metric") |>
drop_na(loss)
Loss curves:
training_prog |>
ggplot(aes(x = epoch, y = loss, color = data)) +
geom_line() +
theme_minimal() +
labs(
title = "Training curves",
subtitle = "Binary cross-entropy loss on training and validation sets, over epochs",
x = "Epochs",
y = "Loss",
color = "Data"
)
Validation metrics:
training_prog |>
select(-c(learning_rate, loss)) |>
pivot_longer(-c(epoch, data), names_to = "metric", values_to = "value") |>
ggplot(aes(x = epoch, y = value, color = data)) +
geom_line() +
facet_wrap(~metric) +
theme_minimal() +
labs(
title = "Training improvements",
subtitle = "Development of metrics over epochs, validation set",
x = "Epochs",
y = "",
color = "Data"
)
Collecting final metrics for training set:
class_metrics <- metric_set(accuracy, precision, recall, f_meas)
mlp_metrics_train <-
mlp$predict(X_train) |>
round() |>
as.vector() |>
tibble(mlp_pred = _) |>
bind_cols(train) |>
mutate(mlp_pred = factor(if_else(mlp_pred == 1, "spam", "no spam"), levels = c("spam", "no spam"))) |>
class_metrics(truth = spam, estimate = mlp_pred) |>
select(-.estimator) |>
pivot_wider(names_from = ".metric", values_from = ".estimate") |>
mutate(name = "Neural Network")
 1/87 ━━━━━━━━━━━━━━━━━━━━ 10s 126ms/step
78/87 ━━━━━━━━━━━━━━━━━━━━ 0s 651us/step
87/87 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step
I will evaluate it on the test set later, along with the competitor; for now I only need the training metrics.
Competitors
10-fold cross-validation:
set.seed(42)
folds <- vfold_cv(train, v = 10, strata = "spam")
To select a proper competitor, I tune three “traditional” models:
- Penalized logistic regression
- Random Forest
- Naive Bayes
Model specifications:
set.seed(42)
log_spec <- logistic_reg(
mode = "classification",
engine = "glmnet",
penalty = tune(),
mixture = 1 # pure L1 regularization
)
rf_spec <- rand_forest(
mode = "classification",
mtry = tune(),
trees = tune(),
min_n = tune()
) |> set_engine("ranger", importance = "impurity") # variable importance
# For naive Bayes, we disable kernel density estimation. This yields
# Gaussian naive Bayes, i.e. we assume normally distributed features.
nb_spec <- naive_Bayes(
mode = "classification",
smoothness = tune()
) |> set_engine("naivebayes", usekernel = FALSE)
Hyperparameter grids. I use regular grid search, since the number of hyperparameters to tune is fairly small (at most three, for the random forest), so strategies like Bayesian tuning don't really pay off here:
log_grid <- tibble(penalty = 10^seq(-5, -1, length.out = 50))
rf_grid <- expand_grid(
mtry = c(4, 8, 12),
trees = c(250, 500, 1000),
min_n = c(5, 10, 20)
)
# Just using raw probabilities (smoothing should have no effect for
# Gaussian naive Bayes with only numerical features, I believe),
# but still cross-validating:
nb_grid <- tibble(smoothness = 0)
Tidymodels does offer workflow_set() and workflow_map() for handling multiple models, but I found them awkward to work with, so I am gluing this together myself:
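For comparison, the workflow_set route would have looked roughly like this (a sketch only; attaching the per-model grids would additionally require option_add(), which is part of what I found awkward):

```r
# Sketch: one recipe crossed with the three model specifications
wf_set <- workflow_set(
  preproc = list(rec = spam_rec),
  models = list(log = log_spec, rf = rf_spec, nb = nb_spec)
)
# workflow_map() would then run the same tuning function over every workflow:
# wf_set |> workflow_map("tune_grid", resamples = folds, metrics = class_metrics)
```

The tribble-based approach I use instead keeps each model's spec, grid, and results together in one tibble row.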
set.seed(42)
competitors <- tribble(
~name, ~spec, ~grid,
"Logistic Regression", log_spec, log_grid,
"Random Forest", rf_spec, rf_grid,
"Naive Bayes", nb_spec, nb_grid
)
competitors <-
competitors |>
mutate(
# bundle recipe & models into workflows:
workflow = map(spec, \(m) workflow() |> add_recipe(spam_rec) |> add_model(m)),
# running grid search for all models:
tuning_res = map2(workflow, grid, function(wf, g) {
tune_grid(
wf,
resamples = folds,
grid = g,
metrics = class_metrics,
control = control_grid(verbose = TRUE, save_pred = TRUE)
)
}),
metrics = map(tuning_res, collect_metrics)
)
The workflow objects are now "pipelines" consisting of the preprocessing recipe followed by the model (see here).
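For illustration, the bundled components can be pulled back out of any of these workflow objects at any time (a quick sanity check; the index 2 picks the Random Forest row from the tribble above):

```r
# Inspecting the parts bundled into a workflow (illustrative):
rf_wf <- competitors$workflow[[2]]  # the Random Forest workflow
extract_preprocessor(rf_wf)         # the preprocessing recipe
extract_spec_parsnip(rf_wf)         # the untrained model specification
```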
Looking at the tuning results for the competitor models. I consider precision the most important metric: false positives (i.e. falsely labeling legitimate "ham" emails as spam) are more costly than false negatives:
competitors |>
select(name, metrics) |>
unnest(metrics) |>
filter(.metric == "precision") |>
arrange(desc(mean)) |>
mutate(
# 95% confidence intervals:
lower = mean - 1.96 * std_err,
upper = mean + 1.96 * std_err
) |>
mutate(rank = row_number()) |>
ggplot(aes(x = rank, y = mean, color = name)) +
geom_point() +
geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.5) +
# benchmark line: the NN's training precision (the metric plotted here):
geom_hline(yintercept = mlp_metrics_train$precision, lty = "dashed", color = "grey") +
annotate(
"text",
x = 70,
y = mlp_metrics_train$precision + 0.003,
label = "Neural Network",
color = "grey50",
size = 4
) +
theme_minimal() +
labs(
title = "Competitors: Tuning Results",
subtitle = "Precision, estimated across 10 folds",
x = "Model Rank",
y = "Precision",
color = "Model",
caption = "Bars indicate 95% confidence intervals"
) +
theme(legend.position = "bottom")
Looks like the random forest has the best chance of performing on par with, or even outperforming, the neural network. Pulling the best hyperparameters:
best_rf_params <-
competitors |>
filter(name == "Random Forest") |>
pull(tuning_res) |>
first() |>
select_best(metric = "precision")
best_rf_params
# A tibble: 1 × 4
mtry trees min_n .config
<dbl> <dbl> <dbl> <chr>
1 4 500 5 Preprocessor1_Model04
Fitting the best model to the training set:
rf_fit <-
competitors |>
filter(name == "Random Forest") |>
pull(workflow) |>
first() |>
finalize_workflow(best_rf_params) |>
fit(train)
Final Random Forest vs. NN on the train set
Collecting train preds & metrics and comparing to neural network:
# Some markdown magic:
mark_best <- function(x) {
map_chr(x, function(val) {
if (val == max(x))
return(paste0("**", as.character(round(val, 3)), "**"))
as.character(round(val, 3))
})
}
rf_fit |>
augment(new_data = train) |>
class_metrics(truth = spam, estimate = .pred_class) |>
pivot_wider(names_from = ".metric", values_from = ".estimate") |>
select(-.estimator) |>
mutate(name = "Random Forest") |>
bind_rows(mlp_metrics_train) |>
select(name, precision, recall, f_meas, accuracy) |>
mutate(across(-name, mark_best)) |>
arrange(name) |>
rename(f1 = f_meas) |>
rename_with(stringr::str_to_title) |>
knitr::kable()
| Name | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|
| Neural Network | 0.935 | 0.919 | 0.927 | 0.943 |
| Random Forest | 0.999 | 0.99 | 0.994 | 0.996 |
At first, those near-perfect train metrics made me do a double take.
This looks like some overfitting. But the selected model is not overly complex (tuning chose 500 over 1000 trees and a low number of randomly sampled predictors, both of which work against model complexity and overfitting), it was validated across 10 folds with performance rivaling the neural network, and it still delivers good test performance (see below). So I concluded this was fine.
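One additional sanity check along these lines: ranger keeps an out-of-bag error estimate, computed only on trees that did not see a given row, so it is immune to train-set memorization. (For probability forests, which is how parsnip fits classification here, ranger reports this as a Brier score rather than a misclassification rate.)

```r
# Out-of-bag error estimate from the underlying ranger model (a sketch;
# for probability forests this is the OOB Brier score)
rf_fit |>
  extract_fit_engine() |>
  purrr::pluck("prediction.error")
```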
Random Forest vs. Neural Network on test set
Test predictions for both (plus predicted class probability):
nn_preds <-
mlp |>
predict(X_test) |>
as.vector() |>
tibble(.pred_spam = _) |>
mutate(
.pred_no_spam = 1 - .pred_spam,
.pred_class = round(.pred_spam),
.pred_class = factor(
if_else(.pred_class == 1, "spam", "no spam"),
ordered = TRUE,
levels = c("spam", "no spam")
),
model = "Neural Network"
) |>
bind_cols(test |> select(actual = spam))
29/29 - 0s - 4ms/step
rf_preds <-
rf_fit |>
predict(test) |>
bind_cols(rf_fit |> predict(test, type = "prob")) |>
rename(.pred_no_spam = `.pred_no spam`) |>
mutate(
model = "Random Forest",
.pred_class = factor(.pred_class, ordered = TRUE, levels = c("spam", "no spam"))
) |>
bind_cols(test |> select(actual = spam))
test_preds <- bind_rows(nn_preds, rf_preds)
Metrics:
test_preds |>
group_by(model) |>
nest() |>
mutate(
metrics = map(data, \(preds) {
preds |>
class_metrics(truth = actual, estimate = .pred_class) |>
select(-.estimator) |>
pivot_wider(names_from = ".metric", values_from = ".estimate")
})
) |>
select(model, metrics) |>
unnest(metrics) |>
rename(f1 = f_meas) |>
ungroup() |>
select(model, precision, recall, f1, accuracy) |>
mutate(across(-model, mark_best)) |>
rename_with(stringr::str_to_title) |>
knitr::kable()
| Model | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|
| Neural Network | 0.932 | 0.912 | 0.922 | 0.939 |
| Random Forest | 0.935 | 0.917 | 0.926 | 0.942 |
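Since false positives are the costly error here, one lever the neural network offers for free is its classification threshold: raising it above the default 0.5 trades recall for precision. A hypothetical sweep, reusing the nn_preds tibble and class_metrics from above:

```r
# Hypothetical: metrics at stricter spam cutoffs than the default 0.5
purrr::map_dfr(c(0.5, 0.6, 0.7, 0.8, 0.9), \(t) {
  nn_preds |>
    mutate(.pred_class = factor(
      if_else(.pred_spam >= t, "spam", "no spam"),
      ordered = TRUE,
      levels = c("spam", "no spam")
    )) |>
    class_metrics(truth = actual, estimate = .pred_class) |>
    mutate(threshold = t)
})
```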
Looking at model confidence. Graphically, we can see that the neural network is more confident in its correct predictions, but also overconfident in its wrong predictions:
confidence <-
test_preds |>
mutate(
confidence = if_else(.pred_class == "spam", .pred_spam, .pred_no_spam),
correct = if_else(actual == .pred_class, "correct", "incorrect")
)
confidence |>
ggplot(aes(x = correct, y = confidence, fill = model, color = model)) +
geom_hline(yintercept = c(.5, 1), lty = "dotted", color = "grey50") +
geom_boxplot(position = position_dodge(width = 0.2), width = .1, outliers = FALSE, alpha = .5) +
theme_minimal() +
labs(
title = "Confidence in predictions",
subtitle = "By correct/incorrect classification",
x = "",
y = "Predicted class probability",
fill = "Model",
color = "Model"
)
Another way of looking at it (I make both plots and decide later which goes in the report):
confidence |>
ggplot(aes(x = confidence, color = model, fill = model)) +
geom_density(alpha = .34) +
facet_wrap(~correct, nrow = 2, scale = "free_y") +
theme_minimal() +
labs(
title = "Confidence in predictions",
subtitle = "By correct/incorrect classification",
x = "Predicted class probability",
y = "Density",
fill = "Model",
color = "Model"
) +
theme(aspect.ratio = .5)
ROC curves:
test_preds |>
group_by(model) |>
roc_curve(truth = actual, .pred_spam) |>
ungroup() |>
ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
geom_line() +
geom_abline(linetype = "dotted", color = "grey50") +
theme_minimal() +
labs(
title = "ROC Curves",
subtitle = "Neural Network & Random Forest, Test set",
x = "1 - Specificity",
y = "Sensitivity",
color = "Model"
) +
theme(aspect.ratio = 1) # square, easier to tell what's going on
Confusion matrices:
test_preds |>
group_by(model) |>
conf_mat(truth = actual, estimate = .pred_class) |>
mutate(plot = map2(model, conf_mat, function(n, c) {
c |>
autoplot(type = "heatmap") +
labs(title = n)
})) |>
pull(plot) |>
patchwork::wrap_plots(ncol = 2) +
theme(aspect.ratio = 1) # make sure they remain squares
RF variable importance:
rf_fit |>
extract_fit_parsnip() |>
vip::vi() |>
slice(1:10) |>
ggplot(aes(x = Importance, y = forcats::fct_reorder(Variable, Importance), color = Variable)) +
geom_segment(aes(xend = 0, yend = Variable), linewidth = 2, alpha = 0.5) +
geom_point(size = 4) +
theme_minimal() +
labs(
title = "Variable Importance",
subtitle = "Random Forest",
x = "Importance (Impurity)",
y = "Feature"
) +
theme(
legend.position = "none",
axis.text.y = element_text(size = 12)
)
Bootstrapped confidence intervals
Since tune::int_pctl() cannot be used with the standalone neural network, I am computing the bootstrap intervals for both models manually.
keras3::set_random_seed(42)
test_boot <-
bootstraps(test, times = 500, strata = "spam") |>
# no splits, we only want to make predictions:
mutate(data = map(splits, analysis)) |>
select(id, data)
nn_boot <-
test_boot |>
mutate(
metrics_dnn = map(data, function(sample) {
# The model is standalone, not a "workflow",
# so we need to send the data through the prep pipeline
# manually & then convert to matrix format
X <-
spam_rec |>
prep() |>
bake(new_data = sample) |>
select(-spam) |>
as.matrix() |>
unname()
mlp$predict(X, verbose = 0) |>
as.vector() |>
round() |>
tibble(mlp_pred = _) |>
bind_cols(sample) |>
mutate(
mlp_pred = factor(
if_else(mlp_pred == 1, "spam", "no spam"), levels = c("spam", "no spam")
)
) |>
class_metrics(truth = spam, estimate = mlp_pred) |>
select(-.estimator) |>
pivot_wider(names_from = ".metric", values_from = ".estimate") |>
mutate(name = "Neural Network")
})
) |>
unnest(metrics_dnn)
rf_boot <-
test_boot |>
mutate(
metrics_rf = map(data, function(sample) {
# This is a workflow object containing the preprocessing pipeline
# and model that can just take any data directly:
rf_fit |>
augment(new_data = sample) |>
class_metrics(truth = spam, estimate = .pred_class) |>
select(-.estimator) |>
pivot_wider(names_from = ".metric", values_from = ".estimate") |>
mutate(name = "Random Forest")
})
) |>
unnest(metrics_rf)
Evaluating:
fns <- list(
mean = mean,
# 95% confidence intervals:
lower = \(x) mean(x) - 1.96 * (sd(x) / sqrt(500)), # 500 = n
upper = \(x) mean(x) + 1.96 * (sd(x) / sqrt(500))
)
boot_res <-
nn_boot |>
select(-c(id, data)) |>
bind_rows(rf_boot |> select(-c(id, data))) |>
pivot_longer(-name, names_to = "metric", values_to = "estimate") |>
group_by(name, metric) |>
summarise(across(estimate, fns, .names = "{fn}")) |>
ungroup()
boot_res
# A tibble: 8 × 5
name metric mean lower upper
<chr> <chr> <dbl> <dbl> <dbl>
1 Neural Network accuracy 0.939 0.939 0.940
2 Neural Network f_meas 0.922 0.921 0.923
3 Neural Network precision 0.933 0.932 0.934
4 Neural Network recall 0.912 0.910 0.913
5 Random Forest accuracy 0.943 0.942 0.943
6 Random Forest f_meas 0.926 0.926 0.927
7 Random Forest precision 0.936 0.935 0.937
8 Random Forest recall 0.917 0.916 0.919
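The intervals above are normal-approximation intervals (mean ± 1.96 standard errors). An alternative that makes no normality assumption is to read the 2.5% and 97.5% quantiles of the bootstrap distribution directly; swapping the summary functions would be enough:

```r
# Percentile bootstrap intervals (alternative to the normal approximation);
# drop-in replacement for the fns list used in summarise() above:
fns_pctl <- list(
  mean  = mean,
  lower = \(x) quantile(x, 0.025),
  upper = \(x) quantile(x, 0.975)
)
```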
Inspecting graphically:
boot_res |>
mutate(metric = if_else(metric == "f_meas", "F1", metric) |> stringr::str_to_title()) |>
ggplot(aes(x = metric, y = mean, color = name)) +
geom_point(position = position_dodge(0.1)) +
geom_errorbar(aes(ymin = lower, ymax = upper), position = position_dodge(0.1), width = .05) +
theme_minimal() +
labs(
title = "Performance on test set",
subtitle = "Bootstrapped 95% confidence intervals",
x = "Metric",
y = "Estimate",
color = ""
)
Footnotes
I tried this (it was tempting to just magically generate rich & uncorrelated features instead of dealing with the existing ones, but that would make the neural network more unstable)↩︎